import warnings
import os
import configparser
import wandb
from tqdm import tqdm
# Configure warning suppression before any other library is imported.
warnings.filterwarnings('ignore')
# Environment variables; set here so they are in place before torch and the
# simulator are imported further below.
os.environ['PYTHONWARNINGS'] = 'ignore'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ["OMP_NUM_THREADS"] = "1"
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# Redirect standard error to the null device.
# NOTE(review): this silences ALL stderr output for the rest of the process,
# including Python tracebacks and native-library errors; disable it when
# debugging. The opened handle is never closed.
import sys
sys.stderr = open(os.devnull, 'w')
# Read `project_path` from global_config.ini located next to this file.
config = configparser.ConfigParser()
config.read(f'{os.path.dirname(os.path.abspath(__file__))}/global_config.ini')
project_path = config['simulation']['project_path']
# Add the parent of this file's directory to the Python path.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.append(project_root)
import random
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
import datetime
# NOTE: the next line rebinds `datetime` to the datetime *class*, shadowing the
# module imported on the previous line.
from datetime import datetime
from collections import defaultdict
import yaml
import omnigibson as og
import json
# `th` is an alias for the same torch module imported above.
import torch as th
from omnigibson.macros import gm
import dill
import hydra
# Make packages under project_path (e.g. diffusion_policy) importable.
sys.path.append(f"{project_path}") 
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
sys.path.append(f"{project_path}/DexDiffusionPolicy")
current_project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # parent of this file's directory (policy_decorator dir)
project_dir = os.path.dirname(current_project_dir)  # one level above that
# Make the sibling Benchmark directory importable.
benchmark_path = os.path.join(project_dir, "Benchmark")
sys.path.append(benchmark_path)
from policy_decorator.utils.profiling import NonOverlappingTimeProfiler
from Benchmark_results.evaluate.env_aug_iter import EnvAugmentor
from omnigibson.tasks.task_base import BaseTask
from copy import deepcopy

# Run with the simulator GUI (i.e. NOT headless).
gm.HEADLESS = False


# Default policy config file path (kept as the fallback config location).
RL_TRAIN_CONFIG_PATH = f"{project_path}/path/to/policy_config.yaml"

# ---- Mutable episode state shared by the reward, env and reset code ----
# Becomes True once the object has risen more than 0.1 m from its start height;
# MyEnv._check_failure uses it to detect a drop after a grasp.
ball_has_life = False
# Current reward stage: 0 = right-hand grasp/lift, 1 = left-hand grasp,
# 2 = handover (right hand releases).
reward_stage_flag = 0
# Consecutive-call counters for the stage 1 -> 2 transitions in
# MyTask._step_reward. That method declares them `global` and increments them
# with `+=`, so they must exist before its first call.
reward_stage_flag2_1_steps = 0
reward_stage_flag2_2_steps = 0
# Object start position / orientation; overwritten by load_environment/env_reset.
obj_start_pos = [0,0,0.98]
obj_start_ori = [0,0,0]
# Consecutive steps the success condition has held (see MyEnv._post_step).
done_steps = 0

# Per-axis object scale candidates cycled through by change_scale().
scale_list = [[0.003,0.001,0.0015],
              [0.0033,0.001,0.0015],
              [0.0027,0.001,0.0015],
              [0.003,0.0013,0.0015],
              [0.003,0.0007,0.0015],
              [0.003,0.001,0.0012],
              ]

# Wider alternative scale range (disabled):
# scale_list = [[0.003,0.001,0.0015],
#               [0.0035,0.001,0.0015],
#               [0.002,0.001,0.0015],
#               [0.003,0.0015,0.0015],
#               [0.003,0.0005,0.0015],
#               [0.003,0.001,0.002],
#               [0.003,0.001,0.0005]
# ]

# Index into scale_list of the scale most recently applied by change_scale().
# change_scale() increments it *before* use, so the first call applies entry 1.
index = 0

def change_scale(env,obj):
    """Apply the next scale from ``env.scale_list`` to ``obj``.

    Advances the module-level ``index`` (wrapping around), stores the chosen
    scale in ``env.obj_scale``, writes it to the ``xformOp:scale`` property of
    the object's first child prim via the ``ChangeProperty`` kit command, and
    refreshes ``env.bbox_extent_in_base_frame`` from the object's base-aligned
    (non-visual) bounding box.
    """
    import omni.kit.commands
    from pxr import Sdf,Gf
    global index

    index = (index + 1) % len(env.scale_list)
    env.obj_scale = env.scale_list[index]

    # The scale op lives on the first child prim of the object.
    child_path = obj._prim.GetAllChildren()[0].GetPath()
    omni.kit.commands.execute(
        "ChangeProperty",
        prop_path=Sdf.Path(f"{child_path}.xformOp:scale"),
        value=Gf.Vec3d(*env.obj_scale),
        prev=None,
    )

    # Re-measure the object so the reward uses the new extent.
    _, _, new_extent, _ = obj.get_base_aligned_bbox(visual=False)
    env.bbox_extent_in_base_frame = new_extent



class MyTask(BaseTask):
    """Task subclass that only defines the staged handover reward ``_step_reward``.

    At module level ``BaseTask._step_reward`` is replaced by this method, so
    tasks that inherit BaseTask's implementation use this reward.
    """

    def _step_reward(self, env, action, info=None):
        """Compute the shaped per-step reward for a right-to-left hand handover.

        The reward is staged through the module-level ``reward_stage_flag``:

        * stage 0 -- right hand grasps and lifts the object (``r1``);
        * stage 1 -- left hand approaches and grasps the lifted object
          (``r1`` + ``r2``);
        * stage 2 -- right hand lets go and opens (``r1`` + ``r2`` + ``r3``).

        Stage transitions mutate module-level globals (``reward_stage_flag``,
        ``reward_stage_flag2_1_steps``, ``reward_stage_flag2_2_steps``), so the
        stage state is shared by every environment in the process.
        NOTE(review): the two counters are only assigned inside this method (the
        ``else`` branches below set them to 0); make sure they are defined at
        module level, otherwise a ``+=`` on the very first call raises NameError.

        Args:
            env: Environment; reads ``env.robots[0]`` and the scene object named
                ``"ball0"``.
            action: Unused here.
            info: Optional dict to update; a new dict is created when None.

        Returns:
            tuple: ``(total_reward, info)``. ``total_reward`` is a scalar tensor,
            and ``info["reward_breakdown"]`` holds the individual terms plus
            the current stage.
        """
        global reward_stage_flag
        global reward_stage_flag2_1_steps
        global reward_stage_flag2_2_steps

        ############# reward design
        # print("reward design")

        # NOTE: _sigmoids and tolerance are defined but not called in this method.
        def _sigmoids(x, value_at_1, sigmoid):
            """Returns 1 when `x` == 0, between 0 and 1 otherwise (torch version).

            Args:
                x: A scalar or PyTorch tensor of shape (batch_size, 1).
                value_at_1: A float between 0 and 1 specifying the output when `x` == 1.
                sigmoid: String, choice of sigmoid type.

            Returns:
                A PyTorch tensor with values between 0.0 and 1.0.

            Raises:
                ValueError: If not 0 < `value_at_1` < 1, except for `linear`, `cosine` and
                `quadratic` sigmoids which allow `value_at_1` == 0.
                ValueError: If `sigmoid` is of an unknown type.
            """
            if sigmoid in ('cosine', 'linear', 'quadratic'):
                if not 0 <= value_at_1 < 1:
                    raise ValueError('`value_at_1` must be nonnegative and smaller than 1, '
                                    'got {}.'.format(value_at_1))
            else:
                if not 0 < value_at_1 < 1:
                    raise ValueError('`value_at_1` must be strictly between 0 and 1, '
                                    'got {}.'.format(value_at_1))

            if sigmoid == 'gaussian':
                scale = torch.sqrt(-2 * torch.log(torch.tensor(value_at_1)))
                return torch.exp(-0.5 * (x * scale) ** 2)

            elif sigmoid == 'hyperbolic':
                scale = torch.acosh(1 / torch.tensor(value_at_1))
                return 1 / torch.cosh(x * scale)

            elif sigmoid == 'long_tail':
                scale = torch.sqrt(1 / torch.tensor(value_at_1) - 1)
                return 1 / ((x * scale) ** 2 + 1)

            elif sigmoid == 'reciprocal':
                scale = 1 / torch.tensor(value_at_1) - 1
                return 1 / (torch.abs(x) * scale + 1)

            elif sigmoid == 'cosine':
                scale = torch.acos(2 * torch.tensor(value_at_1) - 1) / torch.pi
                scaled_x = x * scale
                with warnings.catch_warnings():
                    warnings.filterwarnings(
                        action='ignore', message='invalid value encountered in cos')
                    cos_pi_scaled_x = torch.cos(torch.pi * scaled_x)
                return torch.where(torch.abs(scaled_x) < 1, (1 + cos_pi_scaled_x) / 2, torch.tensor(0.0))

            elif sigmoid == 'linear':
                scale = 1 - torch.tensor(value_at_1)
                scaled_x = x * scale
                return torch.where(torch.abs(scaled_x) < 1, 1 - scaled_x, torch.tensor(0.0))

            elif sigmoid == 'quadratic':
                scale = torch.sqrt(1 - torch.tensor(value_at_1))
                scaled_x = x * scale
                return torch.where(torch.abs(scaled_x) < 1, 1 - scaled_x ** 2, torch.tensor(0.0))

            elif sigmoid == 'tanh_squared':
                scale = torch.atanh(torch.sqrt(1 - torch.tensor(value_at_1)))
                return 1 - torch.tanh(x * scale) ** 2

            else:
                raise ValueError('Unknown sigmoid type {!r}.'.format(sigmoid))

        def tolerance(x, y, r, margin=0.0, sigmoid='gaussian', value_at_margin= 0.1):
            """Return 1 where `x` lies within distance `r` of `y`, decaying towards 0 outside.

            Args:
                x: Tensor of points; the Euclidean distance is taken over the last dim.
                y: Center point, broadcastable against `x`.
                r: Float. The radius around `y` that counts as in-bounds.
                margin: Non-negative float controlling how fast the value decays outside `r`.
                    With `margin == 0` the result is a hard 1/0 indicator.
                sigmoid: String, sigmoid type passed to `_sigmoids`.
                value_at_margin: Output value at distance `r + margin`. Ignored if `margin == 0`.

            Returns:
                A torch tensor with values between 0.0 and 1.0, one per point.

            Raises:
                ValueError: If `margin` is negative.
            """
            if margin < 0:
                raise ValueError('`margin` must be non-negative.')

            # Euclidean distance from each point in x to y
            distance = torch.norm(x - y, p=2, dim=-1)

            in_bounds = distance <= r
            if margin == 0:
                value = torch.where(in_bounds, 1.0, 0.0)
            else:
                d = (distance - r) / margin

                value = torch.where(in_bounds, 1.0, _sigmoids(d, value_at_margin, sigmoid))

            return value

        robot = env.robots[0]
        obj = env.scene.object_registry("name", "ball0")
        obj_pos = torch.tensor(obj.get_position())

        bbox_center_in_world, bbox_quat_in_world, bbox_extent_in_base_frame, bbox_center_in_desired_frame = obj.get_base_aligned_bbox(
            visual=False
        )

        # Lift progress: fraction of the target lift height reached, clipped to [0, 1].
        target_lift_height = 0.16 # target lift height in metres
        r_lift = (obj_pos[2] - obj_start_pos[2]) / (target_lift_height)
        r_lift = np.clip(r_lift, 0.0, 1.0)
        # Right hand: distances from the object to six finger links
        # (thumb / index / middle, two links each), summed and averaged.
        right_hand_thumb_pos = robot._links["hand2_link_1_3"].get_position() # right thumb
        right_hand_index_pos = robot._links["hand2_link_2_2"].get_position() # right index finger
        right_hand_middle_pos = robot._links["hand2_link_3_2"].get_position() # right middle finger
        right_hand_thumb_pos_1 = robot._links["hand2_link_1_1"].get_position()
        right_hand_index_pos_1 = robot._links["hand2_link_2_1"].get_position()
        right_hand_middle_pos_1 = robot._links["hand2_link_3_1"].get_position()

        right_fingers_pos = [right_hand_thumb_pos, right_hand_index_pos, right_hand_middle_pos, right_hand_thumb_pos_1, right_hand_index_pos_1, right_hand_middle_pos_1]
        right_current_finger_dist = sum([torch.norm(obj_pos - finger_pos, p=2, dim=-1) for finger_pos in right_fingers_pos])
        right_current_finger_dist_avg = right_current_finger_dist/len(right_fingers_pos)
        # Author's note: about 0.2 when grasping, about 0.6 when far away.

        # Mean thumb->index / thumb->middle distance; smaller when the hand is closed.
        right_hand_pos_close = torch.norm(right_hand_index_pos - right_hand_thumb_pos)+torch.norm(right_hand_middle_pos - right_hand_thumb_pos)
        right_hand_pos_close_avg = right_hand_pos_close/2


        # Left hand: same measures as for the right hand.
        left_hand_thumb_pos = robot._links["hand1_link_1_3"].get_position()
        left_hand_index_pos = robot._links["hand1_link_2_2"].get_position()
        left_hand_middle_pos = robot._links["hand1_link_3_2"].get_position()
        left_hand_thumb_pos_1 = robot._links["hand1_link_1_1"].get_position()
        left_hand_index_pos_1 = robot._links["hand1_link_2_1"].get_position()
        left_hand_middle_pos_1 = robot._links["hand1_link_3_1"].get_position()

        left_fingers_pos = [left_hand_thumb_pos, left_hand_index_pos, left_hand_middle_pos, left_hand_thumb_pos_1, left_hand_index_pos_1, left_hand_middle_pos_1]
        left_current_finger_dist = sum([torch.norm(obj_pos - finger_pos, p=2, dim=-1) for finger_pos in left_fingers_pos])
        left_current_finger_dist_avg = left_current_finger_dist/len(left_fingers_pos)

        left_hand_pos_close = torch.norm(left_hand_index_pos - left_hand_thumb_pos)+torch.norm(left_hand_middle_pos - left_hand_thumb_pos)
        left_hand_pos_close_avg = left_hand_pos_close/2
        # Midpoint of the left thumb and index base links (the web between thumb and index).
        left_hand_thumb_index_2_center_pos = (left_hand_thumb_pos_1 + left_hand_index_pos_1)/2
        # Author's note: the average distance when grasped is about 0.3455.
        # True when the object lies between the left thumb and left index along
        # world x (their x offsets from the object have opposite signs).
        box_in_lef_hand = (torch.sign(left_hand_thumb_pos[0]-obj_pos[0])*torch.sign(left_hand_index_pos[0]-obj_pos[0])<0)

        #  Make sure info is a dict
        total_info = dict() if info is None else info
        # Reward terms; the stage logic below fills in r1/r2/r3.
        total_reward = torch.tensor(0.0)
        r1 = torch.tensor(0.0)
        r2 = torch.tensor(0.0)
        r3 = torch.tensor(0.0)

        # -1 marks "not computed"; r_right_loose keeps it outside stage 2.
        r_right_close = torch.tensor(-1)
        r_left_close = torch.tensor(-1)
        r_right_loose = torch.tensor(-1)

        # Stage 0 -> 1 once the object has moved more than 0.15 m vertically.
        if reward_stage_flag==0 and abs(obj_pos[2] - obj_start_pos[2]) > 0.15: # 0.15 m
            reward_stage_flag = 1

        # if reward_stage_flag==1 and left_current_finger_dist.mean()<0.2 : # left hand within 5 cm of the object means it is grasped
        #     reward_stage_flag2_1_steps += 1
        #     if reward_stage_flag2_1_steps>=5:
        #         reward_stage_flag=2
        # else:
        #     reward_stage_flag2_1_steps = 0 # when closed, left_hand_pos_close_avg = 0.08...
        # Author's notes: finger dist 1/5=0.2 is far, 0.1 is close; closed at the edge is 0.1777; # hand_close fully closed is 0.05
        # # grasping from the side (0.1465) was also wrongly counted as success

        # Roll back from stage 2 to stage 1 as soon as neither of the two
        # stage-2 entry conditions below still holds.
        if reward_stage_flag==2 and not(box_in_lef_hand and left_current_finger_dist_avg<0.09 and \
            left_hand_pos_close_avg<0.08 and left_hand_pos_close_avg >= bbox_extent_in_base_frame[2]) and \
            not(box_in_lef_hand and left_current_finger_dist_avg<0.12 and right_current_finger_dist>0.35 \
            and left_hand_pos_close_avg<0.09 and obj_pos[2] - obj_start_pos[2]>0.15):
            reward_stage_flag=1


        # Stage 1 -> 2 (a): left hand closed around the object, no narrower than
        # the bbox extent, for 10 consecutive calls.
        if reward_stage_flag==1 and box_in_lef_hand and left_current_finger_dist_avg<0.09 and \
            left_hand_pos_close_avg<0.08 and left_hand_pos_close_avg >= bbox_extent_in_base_frame[2]: # left hand is holding the object
            reward_stage_flag2_1_steps += 1
            if reward_stage_flag2_1_steps>=10:
                reward_stage_flag=2
        else:
            reward_stage_flag2_1_steps = 0


        # Stage 1 -> 2 (b): left hand holds the lifted object while the right
        # hand has moved away, for 5 consecutive calls.
        # NOTE(review): 0.35 is compared with the SUM of six right-finger
        # distances, not the average; confirm this is intended.
        if reward_stage_flag==1 and box_in_lef_hand and left_current_finger_dist_avg<0.12 and right_current_finger_dist>0.35 \
            and left_hand_pos_close_avg<0.09 and obj_pos[2] - obj_start_pos[2]>0.15: # right hand no longer holding
            reward_stage_flag2_2_steps += 1
            if reward_stage_flag2_2_steps>=5:
                reward_stage_flag=2
        else:
            reward_stage_flag2_2_steps = 0


        # Right fingers near the object: 1 while the mean distance <= 0.08, decaying to 0 beyond.
        r_right_close = - torch.exp(5  * th.clip((right_current_finger_dist_avg-0.08), 0.0, None))+2
        r_right_close = np.clip(r_right_close, 0.0, 1.0)
        # Right fingers should close down to roughly the object's bbox extent (z).
        r_right_close_avg = - torch.exp(0.5  * th.clip((right_hand_pos_close_avg-bbox_extent_in_base_frame[2]), 0.0, None))+2
        r_right_close_avg = np.clip(r_right_close_avg, 0.0, 1.0)

        # r_left_close = torch.exp(-0.5  * th.clip(left_current_finger_dist_avg, 0.0, None)) # 0-1
        # Left fingers near the object, same shape as r_right_close.
        r_left_close = - torch.exp(2  * th.clip((left_current_finger_dist_avg-0.01), 0.0, None))+2 #0.05
        r_left_close = np.clip(r_left_close, 0.0, 1.0) 

        # Stage-dependent reward composition.
        if reward_stage_flag==0:
            # ================ Stage 0: right-hand grasp, r1 in [0, 1] ===================
            r1 = torch.tensor(0.4) * r_right_close_avg+ torch.tensor(0.3) * r_right_close + torch.tensor(0.3) * r_lift
            # ================ end of stage 0 ===================

        elif reward_stage_flag==1:
            # ================ Stage 1: left-hand grasp (total in [1, 2] when r1 is full) ===================
            # Keep the stage-0 terms; encourage holding the height, grasping with
            # the left hand and gradually releasing the right hand.
            # r1 = torch.tensor(0.3) + torch.tensor(0.7) * r_lift # 0-1
            # r1 = torch.tensor(0.7) * r_right_close + torch.tensor(0.3) * r_lift
            r1 = torch.tensor(0.4) * r_right_close_avg+ torch.tensor(0.3) * r_right_close + torch.tensor(0.3) * r_lift
            # Author's note: 0.34-0.2 means grasped (reward 1); 0.14/0.34=0.41, rounded to 0.5.
            # r2 = r_left_close
            # 1. Penalize the left fingers closing tighter than the object's bbox extent (z).
            r_penalty = torch.exp(-20* th.clamp(torch.tensor(left_hand_pos_close_avg - bbox_extent_in_base_frame[2]), None, 0))-1
            r_penalty = th.clamp(r_penalty, 0.0, 1.0)
            # r2 = r_left_close - 0.3*r_penalty

            # 2. Penalize the thumb-index web having a y below obj_y + 1.1 * width/2.
            # NOTE(review): `/width_half*1.1` parses as `(.../width_half)*1.1`;
            # confirm `/(width_half*1.1)` was not intended.
            width_half = bbox_extent_in_base_frame[1]/2
            obj_left_y = obj_pos[1] + width_half*1.1
            r_penalty_2 = th.clamp((obj_left_y-left_hand_thumb_index_2_center_pos[1])/width_half*1.1, 0.0, 1.0)

            # 3. Penalize the thumb-index web being below obj_z - height/2.
            height_half = bbox_extent_in_base_frame[0]/2
            obj_left_z = obj_pos[2] - height_half
            r_penalty_3 = th.clamp((obj_left_z-left_hand_thumb_index_2_center_pos[2])/height_half, 0.0, 1.0)
            r2 = r_left_close - 0.5*r_penalty - 0.5*r_penalty_2 - 0.5*r_penalty_3   

            # Halve the left-hand reward unless the object is between thumb and index.
            if not box_in_lef_hand:
                r2 = r2/2
            r2 = np.clip(r2, 0.0, 1.0) 

        elif reward_stage_flag==2:
            # ================ Stage 2: handover (total in [2, 3+]) ===================
            r1 = torch.tensor(0.7) + torch.tensor(0.3) * r_lift # 0-1

            # 1. Penalize the left fingers closing tighter than the object's bbox extent (z).
            r_penalty = torch.exp(-20* th.clamp(torch.tensor(left_hand_pos_close_avg - bbox_extent_in_base_frame[2]), None, 0))-1
            r_penalty = th.clamp(r_penalty, 0.0, 1.0)
            # r2 = r_left_close - 0.3*r_penalty

            # 2. Same y penalty on the thumb-index web as in stage 1.
            width_half = bbox_extent_in_base_frame[1]/2
            obj_left_y = obj_pos[1] + width_half*1.1
            r_penalty_2 = th.clamp((obj_left_y-left_hand_thumb_index_2_center_pos[1])/width_half*1.1, 0.0, 1.0)

            # 3. Same z penalty on the thumb-index web as in stage 1.
            height_half = bbox_extent_in_base_frame[0]/2
            obj_left_z = obj_pos[2] - height_half
            r_penalty_3 = th.clamp((obj_left_z-left_hand_thumb_index_2_center_pos[2])/height_half, 0.0, 1.0)
            r2 = r_left_close - 0.5*r_penalty - 0.5*r_penalty_2 - 0.5*r_penalty_3   

            if not box_in_lef_hand:
                r2 = r2/2
            r2 = np.clip(r2, 0.0, 1.0) 

            # Right fingers moving away from the object (mean distance beyond 0.05), clipped to [0, 1].
            r_right_loose = torch.exp(5  * th.clamp(right_current_finger_dist_avg-0.05,0,None))-1 # 0-1
            r_right_loose = np.clip(r_right_loose, 0.0, 1.0)

            # Right hand: the wider it opens, the larger the reward.
            # Distance from the thumb to the index and middle fingertips.
            right_hand_thumb_index_dist = torch.norm(robot._links["hand2_link_1_3"].get_position() - robot._links["hand2_link_2_2"].get_position())
            right_hand_thumb_middle_dist = torch.norm(robot._links["hand2_link_1_3"].get_position() - robot._links["hand2_link_3_2"].get_position())

            # Average of the two thumb distances.
            right_hand_open_dis = (right_hand_thumb_index_dist + right_hand_thumb_middle_dist)/2
            # Author's note: ~0.1 is closed.

            # Map [min_dis, max_dis] linearly to [0, 1] (not clipped).
            min_dis = 0.04
            max_dis = 0.2
            normalized_dis = (right_hand_open_dis - min_dis) / (max_dis - min_dis)
            # r_right_hand_open = np.clip(normalized_dis, 0.0, 1.0)
            # Author's note: closed 0.05, open 0.08 -> (0, 0.39)   #e(0.09-0.4)
            # 1 - exp(-5 * d): saturating opening reward; negative if normalized_dis < 0.
            r_right_hand_open = -torch.exp(-5  * th.tensor(normalized_dis))+1

            # Alternative: use the normalized distance directly as the reward.
            # r_right_hand_open = normalized_dis

            r3 = 1.2*r_right_loose + 0.8*r_right_hand_open  # include the opening reward in r3


        total_reward = r1 + r2 + r3

        if "reward_breakdown" not in total_info:
            total_info["reward_breakdown"] = {}

        total_info["reward_breakdown"].update({
            "total_reward": total_reward.item(),
            "r1": r1.item(),
            "r2": r2.item(),
            "r3": r3.item(),
            "r_lift": r_lift.item(),
            "r_right_close": r_right_close.item(),
            "r_left_close": r_left_close.item(),
            "r_right_loose": r_right_loose.item(),
            "stage": reward_stage_flag,
        })
        return total_reward, total_info


# Monkey-patch: every task using BaseTask's _step_reward (i.e. not overriding it)
# now computes the staged handover reward defined in MyTask.
BaseTask._step_reward = MyTask._step_reward

# ---- Scripted targets for the initial right-hand grasp (MyEnv.right_lift_init) ----
# Arm targets are 6-D end-effector poses: [x, y, z, 3-D axis-angle].
# NOTE: the right-hand poses below are computed once, at import time, from the
# default obj_start_pos defined above -- not from the object's actual position
# after load_environment()/env_reset() update that global.
left_hand_init_1 = [0.4378,  0.2352,  1.0964, -1.3006,  1.1194, -1.1886]
right_hand_init_1_0 =  [obj_start_pos[0]+0.41,  obj_start_pos[1]-0.19,  obj_start_pos[2]+0.06, -0.408,  1.7137,  0.8488]
right_hand_init_1 =  [obj_start_pos[0]+0.41,  obj_start_pos[1]-0.18,  obj_start_pos[2]+0.06, -0.408,  1.7137,  0.8488]
right_hand_init_2 =  [*right_hand_init_1]
right_hand_init_2[1] += 0.01   # move the right hand 1 cm closer to the object (+y) for the grasp
right_hand_init_3 = [*right_hand_init_2]
right_hand_init_3[2] += 0.25    # lift the right hand 0.25 m; lifting too fast can drop the object (consider two phases)
right_hand_init_3[3] += 0.4   # rotate the right wrist

joint0 = 1.57
finger_open = [3.14, 1.57]  # ~pi, ~pi/2 rad
finger_close = [2.2, 1.2]

# 11 finger-joint targets per hand: joint0 followed by five (a, b) pairs.
left_finger_init_1 = [joint0, *finger_open*5]

right_finger_init_1 = [joint0, *finger_open*5]
right_finger_init_2 = [joint0,  *finger_close*5]

# Full 34-D action layout used by MyEnv.right_lift_init:
#   action[0:6]   left arm end-effector pose (x, y, z, 3-D axis-angle)
#   action[6:12]  right arm end-effector pose
#   action[12:23] left hand joints (11; thumb has 3, other fingers 2 each)
#   action[23:34] right hand joints (11)

class MyEnv(og.Environment):
    """OmniGibson environment for the right-to-left hand handover task.

    Adds a scripted right-hand grasp (``right_lift_init``), failure checks
    (``_check_failure``) and a custom success / truncation rule in
    ``_post_step``.
    """

    # Candidate per-axis object scales cycled through by change_scale().
    scale_list = scale_list
    # Scale currently applied to the object.
    obj_scale = scale_list[0]
    # Base-frame bbox extent of the object; filled in at load / reset time.
    bbox_extent_in_base_frame = None

    def right_lift_init(self):
        """Run the scripted right-hand grasp that starts every episode.

        Steps the environment with fixed 34-D actions: 40 approach steps (two
        passes over two waypoints, 10 steps each), 30 steps closing the right
        fingers, and 30 steps holding a hard-coded lift pose.
        """
        print("right_lift_init")
        # Move the right hand through the pre-grasp waypoints.
        action_ls = [
            np.array([*left_hand_init_1, *right_hand_init_1_0, *left_finger_init_1, *right_finger_init_1]),
            np.array([*left_hand_init_1, *right_hand_init_1, *left_finger_init_1, *right_finger_init_1]),
            ]
        for _ in range(2):
            for action in action_ls:
                for _ in range(10):
                    self.step(action)

        # Close the right fingers around the object (30 steps).
        action_2 = np.array([ *left_hand_init_1, *right_hand_init_2, *left_finger_init_1, *right_finger_init_2])
        for _ in range(30):
            self.step(action_2)

        # Lift (30 steps): hard-coded right-arm pose and right-finger joint
        # targets, presumably recorded from a successful lift. Parametric
        # alternative:
        # np.array([*left_hand_init_1, *right_hand_init_3, *left_finger_init_1, *right_finger_init_2])
        action_3 = np.array([ *left_hand_init_1,
                             0.4460, -0.1737,  1.1361, -0.2105,  1.2280,  0.6486,
                             *left_finger_init_1,
                             1.5261, 0.2738, -1.6449,  1.7990,  0.1671,  1.7588,  0.0471,  1.7985,  0.1668,  1.7572,  1.0414])
        for _ in range(30):
            self.step(action_3)

    def _check_failure(self):
        """Return True when the current episode should count as a failure.

        Failure cases:
          1. The object rose more than 0.1 m from its start height and then came
             back to within 0.02 m of it (dropped after grasping).
          2. After 500 steps the object is still within 0.05 m of its start height.
          3. The object's orientation differs from the start orientation by more
             than 0.15 (L2 norm) while it is at or below its start height.

        Side effect: sets the module-level ``ball_has_life`` once case 1 arms.
        NOTE(review): not called from ``_post_step``; presumably used elsewhere.
        """
        global ball_has_life

        ball = self.scene.object_registry("name", "ball0")
        ball_pos = ball.get_position()
        height_delta = abs(ball_pos[2] - obj_start_pos[2])

        # Case 1: dropped after grasping.
        if height_delta > 0.1:
            ball_has_life = True
        if ball_has_life and height_delta < 0.02:
            print("Failed: Object dropped after grasping")
            return True

        # Case 2: never grasped.
        if height_delta < 0.05 and self._current_step > 500:
            print("Failed: Object not grasped after 500 steps")
            return True

        # Case 3: rotated while lying at or below the start height.
        # NOTE(review): this compares raw quaternions, so q and -q (the same
        # rotation) count as a large deviation.
        ball_ori = ball.get_position_orientation()[1]
        if np.linalg.norm(ball_ori - obj_start_ori) > 0.15 and ball_pos[2] <= obj_start_pos[2]:
            print("Failed: Drop and Rotated")
            return True

        return False

    def _post_step(self,action):
        """Post-process a simulation step into (obs, reward, terminated, truncated, info).

        Success is decided here, not by the task's termination conditions. The
        object must be more than 0.15 m above ``initial_height`` and the reward
        stage must be 2. The right thumb link must be more than 0.15 m and the
        left thumb link less than 0.1 m from the object. All of this must hold
        for 10 consecutive steps. ``truncated`` is set once ``_current_step``
        reaches 800.
        """
        global done_steps

        # NOTE(review): the code below re-implements a full post-step (obs,
        # task.step, auto-reset, `_current_step += 1`). If og.Environment's
        # _post_step does the same, this super() call runs task.step -- and so
        # the staged reward counters -- and the step counter twice per
        # env.step(). The step thresholds in this file may have been tuned with
        # that in place; confirm before removing the call.
        super()._post_step(action)

        obs, obs_info = self.get_obs()

        # Step the scene graph builder if necessary
        if self._scene_graph_builder is not None:
            self._scene_graph_builder.step(self.scene)

        # The task's own `done` is intentionally ignored; success is judged below.
        reward, _task_done, info = self.task.step(self, action)

        ball = self.scene.object_registry("name", "ball0")
        ball_pos = ball.get_position()
        current_height = ball_pos[2]
        height_threshold = 0.15
        # NOTE(review): fixed reference height, below the default start height
        # of 0.98, so relative to that default only ~0.07 m of lift is required.
        initial_height =0.90
        robot = self.robots[0]
        right_hand_pos = robot._links["hand2_link_1_3"].get_position()
        left_hand_pos = robot._links["hand1_link_1_3"].get_position()
        right_dis = th.norm(right_hand_pos - ball_pos)
        left_dis = th.norm(left_hand_pos - ball_pos)

        success_now = (
            current_height - initial_height > height_threshold
            and right_dis > 0.15
            and left_dis < 0.1
            and reward_stage_flag == 2
        )
        if success_now:
            done_steps += 1
        else:
            done_steps = 0
        # `done` depends only on the success streak. The task's `done` must not
        # leak through while the streak is still shorter than 10 steps.
        done = done_steps >= 10
        if done:
            print("success")

        self._populate_info(info)
        info["obs_info"] = obs_info

        if done and self._automatic_reset:
            # Keep the final observation in info, then reset.
            info["last_observation"] = obs
            obs = self.reset()

        # Success terminates the episode; the time limit only truncates it.
        terminated = done

        self._current_step += 1
        truncated = self._current_step >= 800
        return obs, reward, terminated, truncated, info

def env_reset(env, index):
    """Reset the environment and re-run the scripted right-hand grasp.

    Resets the env and robot and lets the simulation settle for 10 steps. It
    then records the object's start position, hands that position to the task
    and cycles the object scale (change_scale). Next it runs
    ``env.right_lift_init()``. Finally it clears the module-level episode state
    that those scripted steps may have modified.

    Args:
        env: The MyEnv instance to reset.
        index: Unused; only the disabled augmentation code below referenced it.
            Kept for caller compatibility. It does not affect the module-level
            ``index`` used by change_scale().
    """
    global obj_start_pos
    global ball_has_life
    global reward_stage_flag
    global reward_stage_flag2_1_steps
    global reward_stage_flag2_2_steps
    global done_steps

    env.reset()
    # Environment augmentation (disabled):
    # env_augmentor = setup_env_augmentor(default_obj_config_paths[index])
    # env_augmentor.restore2defaultstate(env)
    # env_augmentor.iterate_env_aug(env, deepcopy(domain_cfg['aug'][index]))
    # env_augmentor.adjust_light_intensity(env,3.5)
    env.robots[0].reset()
    env.robots[0].keep_still()
    for _ in range(10):
        og.sim.step()
    move_to_initial_position(env)
    obj = env.scene.object_registry("name", "ball0")
    obj_start_pos = obj.get_position()
    env._task.reset_for_new_task(obj_start_pos)  # start position used by the reward
    change_scale(env, obj)
    env.right_lift_init()
    # right_lift_init() steps the env, which runs the staged reward and the
    # success bookkeeping; clear that state so the episode starts at stage 0.
    ball_has_life = False
    reward_stage_flag = 0
    reward_stage_flag2_1_steps = 0
    reward_stage_flag2_2_steps = 0
    done_steps = 0

def load_environment(args):
    """Create MyEnv from the YAML config at ``args.paths.env_cfg`` and run the scripted grasp.

    Records the object's start position/orientation in the module globals,
    stores its base-frame bbox extent on the env, passes the start position to
    the task and calls ``env.right_lift_init()``.

    Returns:
        MyEnv: the initialised environment.
    """
    global obj_start_pos, obj_start_ori

    with open(args.paths.env_cfg, "r") as cfg_file:
        # NOTE(review): FullLoader; prefer yaml.safe_load if configs may be untrusted.
        env_cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)
    env = MyEnv(configs=env_cfg["sim"])
    move_to_initial_position(env)

    ball = env.scene.object_registry("name", "ball0")
    obj_start_pos = ball.get_position()
    obj_start_ori = ball.get_orientation()
    _, _, extent, _ = ball.get_base_aligned_bbox(visual=False)
    env.bbox_extent_in_base_frame = extent

    env._task.reset_for_new_task(obj_start_pos)
    env.right_lift_init()
    return env

def move_to_initial_position(env):
    """Hook for moving the robot to its initial pose; currently a no-op.

    Called by env_reset() and load_environment(); ``env`` is ignored.
    """

def collect_episode_info(infos, result=None):
    """Append return / length / success of finished episodes to ``result``.

    Args:
        infos: Vector-env info dict. When it has a ``"final_info"`` entry,
            ``infos["_final_info"]`` marks which sub-envs finished, and each
            finished entry holds ``"episode"`` (with ``"r"`` and ``"l"``) and
            ``"success"``.
        result: Optional mapping of lists to append to; a new
            ``defaultdict(list)`` is created when None.

    Returns:
        The (possibly new) result mapping.

    Prints one line per finished episode using the module-level ``global_step``.
    """
    collected = defaultdict(list) if result is None else result
    if "final_info" not in infos:
        return collected

    finished_envs = np.where(infos["_final_info"])[0]  # envs finish independently
    for env_idx in finished_envs:
        final = infos["final_info"][env_idx]
        episode = final['episode']
        ep_return = episode['r'][0]
        ep_len = episode["l"][0]
        success = final['success']
        print(f"global_step={global_step}, ep_return={ep_return:.2f}, ep_len={ep_len}, success={success}")
        collected['return'].append(ep_return)
        collected['len'].append(ep_len)
        collected['success'].append(success)
    return collected

def is_ms1_env(env_id):
    """Return True if ``env_id`` contains one of the MS1 task names (OpenCabinet, MoveBucket, PushChair)."""
    ms1_task_names = ('OpenCabinet', 'MoveBucket', 'PushChair')
    return any(task_name in env_id for task_name in ms1_task_names)

def process_image(img, size=(224, 224)):
    """Convert an HWC image to a CHW float tensor in [0, 1], resized to ``size``.

    Only the first three channels are kept, so RGBA input is accepted.

    Args:
        img: ``torch.Tensor`` or ``numpy.ndarray`` of shape (H, W, C) with C >= 3.
        size: Output (height, width); defaults to the 224x224 the model expects.

    Returns:
        torch.Tensor: float32 tensor of shape (3, size[0], size[1]).
    """
    if isinstance(img, torch.Tensor):
        is_uint8 = img.dtype == torch.uint8
        img = img[..., :3].float()  # keep RGB channels only
    else:
        is_uint8 = img.dtype == np.uint8
        img = torch.from_numpy(img[..., :3]).float()

    # [H, W, C] -> [C, H, W]
    img = img.permute(2, 0, 1)

    # uint8 images are always 0-255. For float input, keep the range heuristic:
    # a max <= 1 is taken to mean the image is already in [0, 1]. Checking only
    # the values would leave a very dark uint8 image (max <= 1) unscaled.
    already_unit_range = (not is_uint8) and img.max().item() <= 1.0
    if not already_unit_range:
        img = img / 255.0

    img = F.interpolate(
        img.unsqueeze(0),  # add batch dim
        size=size,
        mode='bilinear',
        align_corners=False
    ).squeeze(0)  # drop batch dim

    return img

def get_obj_state_info(obj):
    """Collect an object's bounding-box pose and velocities.

    Args:
        obj: Object exposing ``get_base_aligned_bbox``, ``get_linear_velocity``
            and ``get_angular_velocity`` (returning tensors).

    Returns:
        tuple: (state tensor, info dict). The tensor concatenates, along dim 0,
        the bbox center, bbox quaternion, linear velocity and angular velocity.
        The dict holds the same four pieces under descriptive keys.
    """
    center, quat, _extent, _center_in_frame = obj.get_base_aligned_bbox(visual=False)
    lin_vel = obj.get_linear_velocity()
    ang_vel = obj.get_angular_velocity()

    pieces = {
        "bbox_center_in_world": center,
        "bbox_quat_in_world": quat,
        "linear_velocity": lin_vel,
        "angular_velocity": ang_vel,
    }
    # Dict insertion order fixes the concatenation order.
    return th.cat(list(pieces.values()), dim=0), pieces

def preprocess_observation(obs, env):
    """Build the policy's observation dict, keeping tensors on CPU.

    Args:
        obs: Raw observation; either a dict or a sequence whose first element
            is the dict.
        env: Environment; its scene object ``"ball0"`` supplies the object state.

    Returns:
        dict: If the observation has a ``'psi'`` entry, it maps each feature
        name to a tensor with two leading singleton dims added. These features
        are the base camera image, slices of the proprio vector, and
        ``obj_state``. Otherwise the raw observation dict is returned unchanged.
    """
    raw_obs = obs if isinstance(obs, dict) else obs[0]

    ball = env.scene.object_registry("name", "ball0")
    obj_state_data, _ = get_obj_state_info(ball)

    # Task-stage flag (default 0.0).
    # NOTE(review): the reward_breakdown written by _step_reward in this file
    # has no "is_lifted" key, so unless obs is populated elsewhere this stays 0.0.
    task_stage_data = th.tensor([0.0])
    if "reward" in raw_obs and "reward_breakdown" in raw_obs["reward"]:
        breakdown = raw_obs["reward"]["reward_breakdown"]
        if "is_lifted" in breakdown:
            task_stage_data = th.tensor([1.0 if breakdown["is_lifted"] else 0.0])

    state_tensor = th.cat([obj_state_data, task_stage_data], dim=0)

    if 'psi' not in raw_obs:
        return raw_obs

    psi_obs = raw_obs['psi']
    proprio = psi_obs['proprio']
    # (feature name, start, end) slices of the proprioception vector.
    proprio_layout = (
        ('joint_qpos', 0, 36),
        ('joint_qvel', 36, 72),
        ('gripper_0_qpos', 72, 83),
        ('gripper_0_qvel', 83, 94),
        ('eef_0_pos', 94, 97),
        ('eef_0_quat', 97, 101),
        ('gripper_1_qpos', 101, 112),
        ('gripper_1_qvel', 112, 123),
        ('eef_1_pos', 123, 126),
        ('eef_1_quat', 126, 130),
    )

    features = {'base_camera_rgb': process_image(psi_obs['psi:base_camera_rgb:Camera:0']['rgb'])}
    for feature_name, start, end in proprio_layout:
        features[feature_name] = proprio[start:end]
    features['obj_state'] = state_tensor

    # Add two leading singleton dims: (...) -> (1, 1, ...). torch.stack copies.
    return {
        feature_name: torch.stack([value], dim=0).unsqueeze(0)
        for feature_name, value in features.items()
    }

def convert_24d_to_34d_action(action, robot, device=None):
    """Expand a 24-dim policy action into the 34-dim action the environment expects.

    Input layout:  [left arm (6), right arm (6), left hand (6), right hand (6)]
    Output layout: [left arm (6), right arm (6), left hand (11), right hand (11)]

    Arm commands pass through unchanged. Each 6-dim hand command is written to slots
    0, 1, 3, 5, 7, 9 of an 11-dim hand vector. Every remaining slot (2, 4, 6, 8, 10)
    is derived from its odd neighbour: that value is normalised to [0, 1] within one
    joint's limits and rescaled into another joint's limits (indices into
    ``robot.joint_lower_limits`` / ``robot.joint_upper_limits``).

    Args:
        action (torch.Tensor): 24-dim action (1-D).
        robot: object exposing ``joint_upper_limits`` and ``joint_lower_limits``.
        device: device for the returned tensor. Defaults to ``action.device`` so the
            function does not depend on a module-level ``device`` variable.

    Returns:
        torch.Tensor: 34-dim action on ``device``.
    """
    if device is None:
        device = action.device

    upper_limit = robot.joint_upper_limits
    lower_limit = robot.joint_lower_limits

    # Slots of the 11-dim hand vector that receive the 6 raw hand commands, in order.
    direct_slots = (0, 1, 3, 5, 7, 9)
    # (target slot, source slot, source joint index, target joint index)
    left_coupled = (
        (2, 1, 24, 34),   # thumb
        (4, 3, 15, 25),   # index finger
        (6, 5, 16, 26),   # middle finger
        (8, 7, 17, 27),   # ring finger
        (10, 9, 18, 28),  # little finger
    )
    right_coupled = (
        (2, 1, 20, 30),   # index finger
        (4, 3, 21, 31),   # middle finger
        (6, 5, 22, 32),   # ring finger
        (8, 7, 23, 33),   # little finger
        (10, 9, 29, 35),  # thumb
    )

    def _expand_hand(hand_command, coupled):
        """Build the 11-dim hand vector from a 6-dim hand command."""
        hand_final = th.zeros(11)
        for src, slot in enumerate(direct_slots):
            hand_final[slot] = hand_command[src]
        for dst_slot, src_slot, src_joint, dst_joint in coupled:
            src_range = upper_limit[src_joint] - lower_limit[src_joint]
            dst_range = upper_limit[dst_joint] - lower_limit[dst_joint]
            norm = (hand_final[src_slot] - lower_limit[src_joint]) / src_range
            hand_final[dst_slot] = norm * dst_range + lower_limit[dst_joint]
        return hand_final

    hand_left_final = _expand_hand(action[12:18], left_coupled)
    hand_right_final = _expand_hand(action[18:24], right_coupled)

    return th.cat(
        (
            action[0:6].to(device),
            action[6:12].to(device),
            hand_left_final.to(device),
            hand_right_final.to(device),
        ),
        0,
    )

def extract_proprio_obj_state(env, obs):
    """Build the flat state vector fed to the residual actor and the critics.

    Concatenates the first 130 proprioception entries, the ``ball0`` object state and
    a one-element stage value (``reward_breakdown["stage"] - 1`` when available,
    otherwise 0.0).

    Args:
        env: environment used to look up the ``ball0`` object.
        obs: raw observation, either a dict or a sequence whose first item is the dict.

    Returns:
        torch.Tensor: 1-D state vector.
    """
    source = obs if isinstance(obs, dict) else obs[0]
    proprio = source['psi']['proprio']

    ball = env.scene.object_registry("name", "ball0")
    ball_state, _ball_info = get_obj_state_info(ball)

    breakdown = {}
    if "reward" in source and "reward_breakdown" in source["reward"]:
        breakdown = source["reward"]["reward_breakdown"]
    stage_value = (breakdown["stage"]-1)*1.0 if "stage" in breakdown else 0.0
    obj_state = th.cat([ball_state, th.tensor([stage_value])], dim=0)

    # proprio[:130] = joint_qpos(36) | joint_qvel(36) | gripper_0 qpos/qvel (11+11)
    #               | eef_0 pos/quat (3+4) | gripper_1 qpos/qvel (11+11) | eef_1 pos/quat (3+4)
    return th.cat([proprio[:130], obj_state])

def generate_action(x, y, z):
    """Build a fixed-pose 34-dim action with mirrored end-effector positions.

    The left end effector targets ``(x, y, z)`` and the right one ``(x, -y, z)``;
    both use the same hard-coded orientation quaternion, converted to axis-angle.
    Both hands are commanded to the same hard-coded 11-joint configuration.

    Args:
        x: target x position for both end effectors.
        y: target y position for the left end effector (negated for the right).
        z: target z position for both end effectors.

    Returns:
        th.Tensor: 34-dim action ``[left pos (3), left axis-angle (3),
        right pos (3), right axis-angle (3), left hand (11), right hand (11)]``.
    """
    # Imported lazily so the module can be loaded without this helper being needed.
    import omnigibson.utils.transform_utils as T

    arm_left_command = th.tensor([x, y, z, -0.579228, 0.4055798, -0.579228, 0.4055798])
    arm_right_command = th.tensor([x, -y, z, -0.579228, 0.4055798, -0.579228, 0.4055798])
    # The quaternion is reordered with [1, 2, 3, 0] before conversion — presumably
    # from (w, x, y, z) to the (x, y, z, w) order quat2axisangle expects; confirm.
    arm_left_angleaxis = T.quat2axisangle(arm_left_command[3:][[1, 2, 3, 0]])
    arm_right_angleaxis = T.quat2axisangle(arm_right_command[3:][[1, 2, 3, 0]])

    arm_left_action = th.cat((arm_left_command[0:3], arm_left_angleaxis), 0)
    arm_right_action = th.cat((arm_right_command[0:3], arm_right_angleaxis), 0)

    hand_left_command = th.tensor([1.57, 0.64, 0.04, 3.11, 1.57, 3.07, 1.57, 3.08, 1.57, 3.05, 1.57])
    hand_right_command = th.tensor([1.57, 0.64, 0.04, 3.11, 1.57, 3.07, 1.57, 3.08, 1.57, 3.05, 1.57])
    hand_action = th.cat((hand_left_command, hand_right_command), 0).cpu()
    action = th.cat((arm_left_action, arm_right_action, hand_action), 0)
    return action

def convert_to_serializable(obj):
    """Recursively convert tensors, arrays and containers into JSON-friendly values.

    - ``torch.Tensor`` and ``numpy.ndarray`` -> (nested) Python list via ``tolist()``
    - numpy scalar (``numpy.generic``) -> Python scalar via ``item()``
    - dict / list / tuple -> same container type with converted contents
    - anything else is returned unchanged

    Args:
        obj: value to convert.

    Returns:
        The converted value.
    """
    if isinstance(obj, (torch.Tensor, np.ndarray)):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # json.dump cannot encode e.g. np.bool_ / np.int64; unwrap to Python scalars.
        return obj.item()
    if isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_to_serializable(item) for item in obj)
    return obj

def load_domain_cfg(args):
    """Parse the env-augmentor YAML config located at ``args.paths.env_augmentor_cfg``.

    Returns:
        The parsed YAML document.
    """
    cfg_path = args.paths.env_augmentor_cfg
    with open(cfg_path, "r") as cfg_file:
        return yaml.load(cfg_file, Loader=yaml.FullLoader)

def load_default_obj(default_obj_config_dir_path):
    """List the ``.yaml`` / ``.yml`` files in a directory as full paths, sorted by file name.

    Sorting makes the order deterministic: ``os.listdir`` order is arbitrary, and
    callers pick the first entry and cycle through the list by index.

    Args:
        default_obj_config_dir_path: directory containing object config files.

    Returns:
        list[str]: matching file paths (empty if there are none).
    """
    return [
        os.path.join(default_obj_config_dir_path, name)
        for name in sorted(os.listdir(default_obj_config_dir_path))
        if name.endswith((".yaml", ".yml"))
    ]

def initialize_training(args):
    """Build every object the residual-SAC training loop needs.

    Sets up the tensorboard log directory and results directory (plus wandb when
    ``args.track``), seeds ``random``/``numpy``/``torch``, loads the environment, the
    env augmentor and the frozen base diffusion policy, creates the residual actor,
    twin Q critics with target copies, their optimizers, optional entropy
    auto-tuning state and the replay buffer, and writes ``domain_info.json``.

    Args:
        args: namespace produced by ``parse_args``.

    Returns:
        tuple: (envs, base_policy, res_actor, qf1, qf2, qf1_target, qf2_target,
        q_optimizer, actor_optimizer, log_sac_alpha, a_optimizer, sac_alpha, rb,
        writer, timer, global_step, global_update, learning_has_started,
        num_updates_per_training, result, ball, episode_count, episode_successes,
        episode_log, res_action_usage, last_saved_count, results_dir,
        env_augmentor, device, action_dim, default_obj_config_paths,
        target_entropy, log_path) — unpacked positionally by the caller.

    Raises:
        FileNotFoundError: if ``args.paths.default_obj_config_dir_path`` contains no
            ``.yaml`` / ``.yml`` file.
    """
    # --- Logging paths and directories ---------------------------------------
    now = datetime.now().strftime("%y%m%d-%H%M%S")
    tag = f"{now}_{args.seed}"
    if args.exp_name:
        tag += f"_{args.exp_name}"
    log_name = os.path.join(args.env_id, args.algo_name, tag)
    log_path = os.path.join(args.output_dir, log_name)
    results_dir = setup_logging_dir(args)
    print(f"Saving results to: {results_dir}")

    if args.track:
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=log_name.replace(os.path.sep, "__"),
            monitor_gym=True,
            save_code=True,
        )

    # --- Tensorboard + saved args --------------------------------------------
    writer = SummaryWriter(log_path)
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )
    with open(f'{log_path}/args.json', 'w') as f:
        json.dump(convert_to_json_serializable(args), f, indent=4)

    # --- Seeding ---------------------------------------------------------------
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # --- Environment and frozen base policy -----------------------------------
    envs = load_environment(args)
    default_obj_config_paths = load_default_obj(args.paths.default_obj_config_dir_path)
    if not default_obj_config_paths:
        raise FileNotFoundError(
            f"No .yaml/.yml object configs found in {args.paths.default_obj_config_dir_path}"
        )
    env_augmentor = setup_env_augmentor(default_obj_config_paths[0])

    base_policy = load_diffusion_policy_model(args.base_policy_ckpt)
    base_policy.eval()
    base_policy.requires_grad_(False)

    # --- Network dimensions ----------------------------------------------------
    # joint_qpos + joint_qvel + gripper_0_qpos + gripper_0_qvel + eef_0_pos + eef_0_quat
    # + gripper_1_qpos + gripper_1_qvel + eef_1_pos + eef_1_quat + obj_state + stage
    proprio_obj_state_dim = (36 + 36 + 11 + 11 + 3 + 4 + 11 + 11 + 3 + 4 + 13 + 1)
    action_dim = 34  # robot action dimension
    action_dim_for_critic = action_dim * 2 if args.critic_input == 'concat' else action_dim

    # --- Networks and optimizers ----------------------------------------------
    res_actor = Actor(envs, proprio_obj_state_dim, action_dim, args).to(device)
    qf1 = SoftQNetwork(proprio_obj_state_dim, action_dim_for_critic).to(device)
    qf2 = SoftQNetwork(proprio_obj_state_dim, action_dim_for_critic).to(device)
    qf1_target = SoftQNetwork(proprio_obj_state_dim, action_dim_for_critic).to(device)
    qf2_target = SoftQNetwork(proprio_obj_state_dim, action_dim_for_critic).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())

    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)
    actor_optimizer = optim.Adam(list(res_actor.parameters()), lr=args.policy_lr)

    # --- Entropy tuning ----------------------------------------------------------
    # Defined unconditionally because it is part of the returned tuple (previously it
    # was unbound when autotune was off); it only matters when args.autotune is set.
    target_entropy = -action_dim
    if args.autotune:
        log_sac_alpha = torch.zeros(1, requires_grad=True, device=device)
        sac_alpha = log_sac_alpha.exp().item()
        a_optimizer = optim.Adam([log_sac_alpha], lr=args.q_lr)
    else:
        sac_alpha = args.sac_alpha
        log_sac_alpha, a_optimizer = None, None

    # --- Replay buffer -------------------------------------------------------------
    # Width of one stored action row. The collection loop stores only the residual
    # action when critic_input == 'res' and actor_input == 'obs'; otherwise it stores
    # [res_action, base_action, base_next_action], i.e. 3 * action_dim values.
    stores_residual_only = args.critic_input == 'res' and args.actor_input == 'obs'
    stored_action_dim = action_dim if stores_residual_only else action_dim * 3
    rb = ReplayBuffer(
        args.buffer_size,
        observation_space=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(proprio_obj_state_dim,)),
        action_space=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(stored_action_dim,)),
        device=device,
        handle_timeout_termination=False,
    )

    # --- Domain info -----------------------------------------------------------------
    domain_info = get_domain_info(args)
    domain_info_path = os.path.join(results_dir, "domain_info.json")
    with open(domain_info_path, "w") as f:
        json.dump(convert_to_serializable(domain_info), f, indent=4)

    # --- Training-loop bookkeeping ----------------------------------------------------
    global_step = 0
    global_update = 0
    learning_has_started = False
    num_updates_per_training = int(args.training_freq * args.utd)
    result = defaultdict(list)
    timer = NonOverlappingTimeProfiler()

    ball = envs.scene.object_registry("name", "ball0")

    episode_count = 0
    episode_successes = []
    episode_log = []
    res_action_usage = {'total_steps': 0, 'res_steps': 0}
    last_saved_count = 0

    return (envs, base_policy, res_actor, qf1, qf2, qf1_target, qf2_target,
            q_optimizer, actor_optimizer, log_sac_alpha, a_optimizer, sac_alpha,
            rb, writer, timer, global_step, global_update, learning_has_started,
            num_updates_per_training, result, ball, episode_count,
            episode_successes, episode_log, res_action_usage, last_saved_count,
            results_dir, env_augmentor, device, action_dim, default_obj_config_paths,
            target_entropy, log_path)

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` with gain ``std`` and fill ``layer.bias``.

    Args:
        layer: module with ``weight`` and ``bias`` parameters (e.g. ``nn.Linear``).
        std: gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: constant written into every bias entry.

    Returns:
        The same ``layer`` object, so the call can be used inline.
    """
    torch.nn.init.orthogonal_(layer.weight, gain=std)
    torch.nn.init.constant_(layer.bias, val=bias_const)
    return layer

def load_checkpoint(path, res_actor, qf1, qf2, qf1_target, qf2_target, log_sac_alpha,
                   actor_optimizer, q_optimizer, a_optimizer=None, policy_lr=1.0e-4, q_lr=1.0e-4,
                   load_global_step=True, map_location=None):
    """Restore model and optimizer state from a checkpoint, then reset learning rates.

    Every entry except ``res_actor`` is optional in the checkpoint. Target critics are
    loaded from ``q1_target``/``q2_target`` when present, otherwise from the online
    ``q1``/``q2`` weights when those exist, otherwise left untouched.

    After an optimizer state is restored its learning rate is overwritten:
    actor -> ``policy_lr * 0.5``, critics -> ``q_lr * 0.5``, entropy -> ``q_lr * 0.1``.

    Args:
        path: checkpoint file path.
        res_actor: residual policy network.
        qf1, qf2: online Q networks.
        qf1_target, qf2_target: target Q networks.
        log_sac_alpha: log entropy coefficient tensor, or None to skip.
        actor_optimizer: policy optimizer.
        q_optimizer: Q-network optimizer.
        a_optimizer: entropy-coefficient optimizer (only with auto-tuning), or None.
        policy_lr: base policy learning rate.
        q_lr: base Q-network learning rate.
        load_global_step: return the checkpoint's global step (True) or 0 (False).
        map_location: forwarded to ``torch.load`` (e.g. ``"cpu"``).

    Returns:
        tuple[bool, int]: (loaded, global_step); ``(False, 0)`` if ``path`` does not exist.
    """
    if not os.path.exists(path):
        print(f"检查点不存在: {path}")
        return False, 0

    print(f"正在加载检查点: {path}")
    checkpoint = torch.load(path, map_location=map_location)

    # Model weights
    res_actor.load_state_dict(checkpoint['res_actor'])
    if 'current_res_scale' in checkpoint:
        res_actor.current_res_scale = checkpoint['current_res_scale']
        print(f"加载残差比例: {res_actor.current_res_scale}")

    for key, net in (('q1', qf1), ('q2', qf2)):
        if key in checkpoint:
            net.load_state_dict(checkpoint[key])

    # Targets: prefer saved targets, fall back to the online weights only if present.
    for target_key, online_key, net in (('q1_target', 'q1', qf1_target),
                                        ('q2_target', 'q2', qf2_target)):
        if target_key in checkpoint:
            net.load_state_dict(checkpoint[target_key])
        elif online_key in checkpoint:
            net.load_state_dict(checkpoint[online_key])

    # Entropy coefficient (assigning .data keeps the tensor object the optimizer references)
    if 'log_sac_alpha' in checkpoint and log_sac_alpha is not None:
        log_sac_alpha.data = checkpoint['log_sac_alpha']

    # Optimizer states, with learning rates reset from the supplied base values
    if 'actor_optimizer' in checkpoint:
        actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        for param_group in actor_optimizer.param_groups:
            param_group['lr'] = policy_lr * 0.5
            print(f"policy_lr: {param_group['lr']}")

    if 'q_optimizer' in checkpoint:
        q_optimizer.load_state_dict(checkpoint['q_optimizer'])
        for param_group in q_optimizer.param_groups:
            param_group['lr'] = q_lr * 0.5
            print(f"q_lr: {param_group['lr']}")

    if a_optimizer is not None and 'a_optimizer' in checkpoint:
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        for param_group in a_optimizer.param_groups:
            param_group['lr'] = q_lr * 0.1

    checkpoint_global_step = checkpoint.get('global_step', 0)
    if load_global_step:
        print(f"加载成功! 当前步数: {checkpoint_global_step}")
        return True, checkpoint_global_step

    print(f"加载成功! 原始步数: {checkpoint_global_step}，重置为0")
    return True, 0


# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
    """Q(s, a) critic: a 3-hidden-layer ReLU MLP over the concatenated state and action."""

    def __init__(self, obs_dim, action_dim):
        super().__init__()
        hidden = 256
        self.net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            # Output layer gets a small orthogonal gain (0.01) and zero bias.
            layer_init(nn.Linear(hidden, 1), std=0.01),
        )

    def forward(self, x, a):
        """Return Q-values of shape (batch, 1) for states ``x`` and actions ``a``."""
        state_action = torch.cat([x, a], dim=1)
        return self.net(state_action)
# Bounds that Actor.forward squashes the policy's log standard deviation into.
LOG_STD_MAX, LOG_STD_MIN = 2, -20
def get_action_space_boundaries(env):
    """Compute the per-dimension scale/bias that maps [-1, 1] onto the 34-dim action range.

    Action layout: [left EEF (6), right EEF (6), left hand joints (11), right hand joints (11)].
    Each EEF block is position (3) bounded by [-1, 1] followed by orientation (3)
    bounded by [-1.5, 1.5]; hand bounds come from the first robot's joint limits.

    NOTE(review): ``convert_24d_to_34d_action`` remaps hand slots using a different
    joint order (e.g. left slot 2 <-> joint 34, slot 4 <-> joint 25); confirm that the
    ordering used here matches the hand controller's action layout.

    Args:
        env: environment whose ``scene.robots[0]`` exposes joint limits.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(action_scale, action_bias)``, float32 of
        shape (34,), so that ``action = y * action_scale + action_bias`` for y in [-1, 1].
    """
    robot = env.scene.robots[0]
    upper = robot.joint_upper_limits
    lower = robot.joint_lower_limits

    low = np.zeros(34)
    high = np.zeros(34)

    # Both arm EEF blocks share the same fixed bounds.
    for eef_start in (0, 6):
        low[eef_start:eef_start + 3] = -1.0
        high[eef_start:eef_start + 3] = 1.0
        low[eef_start + 3:eef_start + 6] = -1.5
        high[eef_start + 3:eef_start + 6] = 1.5

    # (first action slot, robot joint indices for the 11 hand slots)
    hand_joint_map = (
        (12, (14, 15, 16, 17, 18, 24, 25, 26, 27, 28, 34)),  # left hand
        (23, (19, 20, 21, 22, 23, 29, 30, 31, 32, 33, 35)),  # right hand
    )
    for first_slot, joint_ids in hand_joint_map:
        for slot, joint_id in enumerate(joint_ids, start=first_slot):
            low[slot] = lower[joint_id]
            high[slot] = upper[joint_id]

    low_t = torch.tensor(low, dtype=torch.float32)
    high_t = torch.tensor(high, dtype=torch.float32)
    return (high_t - low_t) / 2.0, (high_t + low_t) / 2.0

class Actor(nn.Module):
    """Tanh-squashed Gaussian residual policy (SAC-style).

    A 3-layer ReLU backbone feeds two heads: the Gaussian mean and the log standard
    deviation. Sampled actions are squashed with ``tanh`` and, when
    ``args.use_action_scaling`` is true (the default), affinely mapped into the range
    given by ``get_action_space_boundaries(env)``.
    """

    def __init__(self, env, obs_dim, action_dim, args):
        super().__init__()
        # actor_input == 'obs': state only; otherwise the base action is appended.
        input_dim = obs_dim if args.actor_input == 'obs' else obs_dim + action_dim
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        self.fc_mean = layer_init(nn.Linear(256, action_dim), std=0.01)
        self.fc_logstd = layer_init(nn.Linear(256, action_dim), std=0.01)

        # Whether to map tanh outputs into the environment's action range.
        self.use_action_scaling = getattr(args, 'use_action_scaling', True)

        if self.use_action_scaling:
            action_scale, action_bias = get_action_space_boundaries(env)
            # Buffers move with .to() and are saved in state_dict, but are not trained.
            self.register_buffer("action_scale", action_scale)
            self.register_buffer("action_bias", action_bias)

    def forward(self, x):
        """Return ``(mean, log_std)``; ``log_std`` is squashed into [LOG_STD_MIN, LOG_STD_MAX]."""
        x = self.backbone(x)
        mean = self.fc_mean(x)
        log_std = self.fc_logstd(x)
        log_std = torch.tanh(log_std)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats

        return mean, log_std

    def get_action(self, x, deterministic=False):
        """Sample (or take the mean of) the squashed Gaussian policy.

        Args:
            x: actor input batch of shape (batch, input_dim).
            deterministic: use the Gaussian mean instead of a reparameterised sample.

        Returns:
            tuple: ``(action, log_prob, mean_action)`` where ``log_prob`` has shape
            (batch, 1) and includes the tanh (and scaling) change-of-variables term.
        """
        mean, log_std = self(x)
        std = log_std.exp()
        normal = torch.distributions.Normal(mean, std)

        if deterministic:
            x_t = mean
        else:
            # Reparameterisation trick keeps the sample differentiable.
            x_t = normal.rsample()

        y_t = torch.tanh(x_t)

        if self.use_action_scaling:
            action = y_t * self.action_scale + self.action_bias
            log_prob = normal.log_prob(x_t)
            # NOTE(review): a zero entry in action_scale (equal joint limits) would make
            # this log -inf; the 1e-6 only guards the tanh derivative.
            denom = self.action_scale * (1 - y_t.pow(2) + 1e-6)
            log_prob -= torch.log(denom)
        else:
            action = y_t
            log_prob = normal.log_prob(x_t)
            log_prob -= torch.log((1 - y_t.pow(2)) + 1e-6)

        # Joint log-probability over action dimensions.
        log_prob = log_prob.sum(1, keepdim=True)

        mean_tanh = torch.tanh(mean)
        if self.use_action_scaling:
            mean_action = mean_tanh * self.action_scale + self.action_bias
        else:
            mean_action = mean_tanh

        return action, log_prob, mean_action

    def get_eval_action(self, x):
        """Deterministic action for evaluation: squashed (and optionally scaled) mean."""
        x = self.backbone(x)
        mean = self.fc_mean(x)
        mean = torch.tanh(mean)

        if self.use_action_scaling:
            action = mean * self.action_scale + self.action_bias
        else:
            action = mean

        return action

    def to(self, *args, **kwargs):
        """Forward every call form ``nn.Module.to`` accepts (device, dtype, tensor, kwargs)."""
        return super().to(*args, **kwargs)

# Module-level placeholder; not read or written anywhere in this part of the file.
domain_cfg = None


def get_domain_info(args):
    """Collect domain information to be saved alongside the run's results.

    Currently records only the RL training config under ``'rl_train_config'``. The
    config is re-read from ``args.config_path`` when ``args`` carries that attribute
    (the file actually used for this run), otherwise from ``RL_TRAIN_CONFIG_PATH``.

    Args:
        args: training configuration namespace.

    Returns:
        dict: ``{'rl_train_config': <parsed YAML>}``.
    """
    config_path = getattr(args, "config_path", None) or RL_TRAIN_CONFIG_PATH
    with open(config_path, "r") as f:
        rl_config = yaml.load(f, Loader=yaml.FullLoader)
    return {'rl_train_config': rl_config}

def setup_logging_dir(args):
    """Create (if needed) and return ``<args.paths.results_base>/<YYYYmmdd_HHMMSS>``.

    The directory name is the current local time; an existing directory is reused.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(args.paths.results_base, stamp)
    os.makedirs(run_dir, exist_ok=True)
    return run_dir

def setup_env_augmentor(default_obj_config_path):
    """Create an ``EnvAugmentor`` for one default object config and the project's default texture."""
    texture_path = f"{project_path}/Benchmark_results/assets/texture/texture_d0/20241126-171022.jpg"
    return EnvAugmentor(
        default_obj_config_path=default_obj_config_path,
        default_texture_path=texture_path,
    )

def convert_to_json_serializable(obj):
    """Recursively turn ``obj`` into something ``json.dump`` accepts.

    Dicts are converted value by value; any other object with a ``__dict__`` becomes a
    dict of its public (non-underscore) attributes; lists and tuples become lists;
    ``int``/``float``/``str``/``bool``/``None`` pass through; anything else is ``str()``-ed.
    """
    if isinstance(obj, dict):
        return {key: convert_to_json_serializable(val) for key, val in obj.items()}
    if hasattr(obj, '__dict__'):
        public_attrs = {key: val for key, val in vars(obj).items() if not key.startswith('_')}
        return {key: convert_to_json_serializable(val) for key, val in public_attrs.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(convert_to_json_serializable, obj))
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    return str(obj)

def parse_args():
    """Build the training configuration from a YAML file plus command-line overrides.

    The YAML file (``--config``, default ``RL_TRAIN_CONFIG_PATH``) supplies every
    setting; its ``paths`` mapping becomes a ``SimpleNamespace`` so entries are read as
    ``args.paths.<name>``. ``--exp-name``, ``--seed`` and ``--resume-path`` override the
    file. The config file path is stored as ``args.config_path``; ``exp_name`` and
    ``resume_path`` default to None when neither the file nor the CLI sets them.

    Returns:
        argparse.Namespace: the merged configuration.

    Raises:
        ValueError: if the evaluation / training-frequency settings are inconsistent.
    """
    import argparse
    from types import SimpleNamespace

    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=RL_TRAIN_CONFIG_PATH,
                        help="path to config file")
    parser.add_argument("--exp-name", type=str, help="override experiment name")
    parser.add_argument("--seed", type=int, help="override random seed")
    parser.add_argument("--resume-path", type=str, help="path to checkpoint to resume training from")
    args_cmd = parser.parse_args()

    with open(args_cmd.config, 'r') as f:
        args_yaml = yaml.safe_load(f)

    # Expose `paths` entries as attributes (args.paths.<name>); tolerate `paths: null`.
    paths_dict = args_yaml.pop('paths', None) or {}
    args_yaml['paths'] = SimpleNamespace(**paths_dict)

    args = argparse.Namespace(**args_yaml)
    # Record which config file this run used (e.g. for saving alongside results).
    args.config_path = args_cmd.config

    # Command-line overrides; fall back to None for keys read unconditionally later.
    if args_cmd.exp_name is not None:
        args.exp_name = args_cmd.exp_name
    elif not hasattr(args, 'exp_name'):
        args.exp_name = None
    if args_cmd.seed is not None:
        args.seed = args_cmd.seed
    if args_cmd.resume_path is not None:
        args.resume_path = args_cmd.resume_path
    elif not hasattr(args, 'resume_path'):
        args.resume_path = None

    # Post-processing
    if getattr(args, 'buffer_size', None) is None:
        args.buffer_size = args.total_timesteps
    args.buffer_size = min(args.total_timesteps, args.buffer_size)
    args.num_eval_envs = min(args.num_eval_envs, args.num_eval_episodes)

    # Validation (explicit raises: asserts are stripped under `python -O`)
    if args.num_eval_episodes % args.num_eval_envs != 0:
        raise ValueError(
            f"num_eval_episodes ({args.num_eval_episodes}) must be a multiple of "
            f"num_eval_envs ({args.num_eval_envs})"
        )
    if args.training_freq % args.num_envs != 0:
        raise ValueError(
            f"training_freq ({args.training_freq}) must be a multiple of num_envs ({args.num_envs})"
        )
    # float(): int.is_integer() only exists on Python >= 3.12.
    if not float(args.training_freq * args.utd).is_integer():
        raise ValueError(
            f"training_freq * utd must be an integer (got {args.training_freq} * {args.utd})"
        )

    return args

###############ours####################
def load_diffusion_policy_model(checkpoint_path):
    """Load a diffusion policy from a workspace checkpoint and prepare it for inference.

    The checkpoint's ``cfg._target_`` workspace class is instantiated, the payload is
    loaded into it, and ``workspace.model`` is put in eval mode on CUDA (CPU when CUDA
    is unavailable) with 4 inference steps and 8 action steps.

    Args:
        checkpoint_path: path to a dill-pickled workspace checkpoint.

    Returns:
        The policy (``workspace.model``).

    Raises:
        ValueError: if the checkpoint's ``cfg.name`` does not contain ``'diffusion'``.
    """
    import ssl
    # NOTE(security): disables HTTPS certificate verification process-wide, presumably
    # to reach the HF mirror; revisit if this process talks to other services.
    ssl._create_default_https_context = ssl._create_unverified_context

    import requests
    from requests.packages.urllib3.exceptions import InsecureRequestWarning
    requests.packages.urllib3.disable_warnings(InsecureRequestWarning)

    # huggingface_hub setting
    os.environ['HF_HUB_DISABLE_SSL_VERIFICATION'] = '1'

    # NOTE(security): unpickling with dill executes arbitrary code; load only trusted checkpoints.
    with open(checkpoint_path, 'rb') as f:
        payload = torch.load(f, pickle_module=dill)
    cfg = payload['cfg']

    cls = hydra.utils.get_class(cfg._target_)
    workspace: BaseWorkspace = cls(cfg)
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    if 'diffusion' not in cfg.name:
        raise ValueError(
            f"Unsupported policy checkpoint (cfg.name={cfg.name!r}); expected a diffusion policy"
        )

    policy: BaseImagePolicy = workspace.model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    policy.eval().to(device)
    # Inference parameters
    policy.num_inference_steps = 4  # DDIM inference iterations
    policy.n_action_steps = 8
    return policy



if __name__ == "__main__":
    args = parse_args()
    #初始化训练参数
    (envs, base_policy, res_actor, qf1, qf2, qf1_target, qf2_target,
     q_optimizer, actor_optimizer, log_sac_alpha, a_optimizer, sac_alpha,
     rb, writer, timer, global_step, global_update, learning_has_started,
     num_updates_per_training, result, ball, episode_count,
     episode_successes, episode_log, res_action_usage, last_saved_count,
     results_dir,env_augmentor,device,action_dim,default_obj_config_paths,target_entropy,log_path) = initialize_training(args)

    # 加载检查点（如果有）
    if args.resume_path:
        success, loaded_global_step = load_checkpoint(
            args.resume_path,
            res_actor, qf1, qf2, qf1_target, qf2_target,
            log_sac_alpha if args.autotune else None,
            actor_optimizer, q_optimizer,
            a_optimizer if args.autotune else None,
            policy_lr=args.policy_lr,
            q_lr=args.q_lr,
            load_global_step=args.load_global_step  # 使用命令行参数控制是否加载全局步数
        )
        if success:
            global_step = loaded_global_step
            print(f"成功从步骤 {global_step} 恢复训练")

    ################################## 主代码 #########################################################
    index = 0

    obs = envs.get_obs()
    while global_step < args.total_timesteps: 

        max_reward_in_episode = -10
        max_reward_in_episode_ls = []

        # 从环境中收集样本
        for local_step in tqdm(range(args.training_freq)): 
            global_step += 1
            res_action_usage['total_steps'] += 1
            obs_dict = preprocess_observation(obs, envs)
            with torch.no_grad():
                base_act_seq = base_policy.predict_action(obs_dict)['action'] # (B, act_horizon, act_dim)
                base_action = base_act_seq.squeeze(0)[0] #(-1, total_act_dim)
                base_action = convert_24d_to_34d_action(base_action, envs.scene.robots[0])
                base_act_seq = base_action
            # 渐进式探索
            res_ratio = min(global_step / args.prog_explore, 1) 
            enable_res = np.random.rand() < res_ratio
            if(global_step<=args.prog_explore_th):
                enable_res=False
            # ALGO LOGIC: put action logic here
            # 获取残差动作
            proprio_obj_state = extract_proprio_obj_state(envs,obs)
            proprio_obj_state_tensor = torch.FloatTensor(proprio_obj_state).unsqueeze(0).to(device)

            if not learning_has_started:
                res_action = np.random.uniform(-1, 1, size=action_dim) 
                res_action = np.zeros_like(res_action)
            else:
                actor_input = proprio_obj_state_tensor if args.actor_input == 'obs' else \
                    torch.cat([proprio_obj_state_tensor, torch.FloatTensor(base_action).unsqueeze(0).to(device)], dim=1)
                res_action, _, _ = res_actor.get_action(actor_input)
                res_action = res_action.detach().cpu().numpy()
            if not enable_res:
                res_action = np.zeros_like(res_action)

            # 组合基准动作和残差动作
            act_dim=34
            res_act_seq = res_action.reshape(-1, act_dim)
            scaled_res_seq = args.res_scale * res_act_seq # (B, act_horizon, act_dim)
            # 将scaled_res_seq转换为tensor
            scaled_res_seq = torch.FloatTensor(scaled_res_seq).to(device)
            # print("scaled_res_seq:",scaled_res_seq.shape)
            if enable_res and learning_has_started:
                res_action_usage['res_steps'] += 1
                # print("global_step:",global_step)
                # print("base_act_seq:",base_action)
                if global_step % 10000 == 0:
                    print("scaled_res_seq:",scaled_res_seq)

            print(scaled_res_seq[0])
            final_act_seq = base_act_seq.detach().cpu().numpy() + scaled_res_seq[0].detach().cpu().numpy()
            #final_act_seq = final_act_seq[0, 0, :]  # Get first action [18]

            # TRY NOT TO MODIFY: execute the game and log data.
            # 执行动作并获取环境反馈
            #final_act_seq = final_act_seq.detach().cpu().numpy()
            next_obs_seq, rewards, terminations, truncations, infos = envs.step(final_act_seq)
            next_obs_seq = envs.get_obs()
            
            rewards = rewards - 4.0 # negative reward + bootstrap at truncated yields best results
            # TRY NOT TO MODIFY: record rewards for plotting purposes
            # 记录训练数据
            result = collect_episode_info(infos, result)

            
            # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
            # 处理经验回放缓冲区的数据存储
            fail= envs._check_failure()
            if fail:
                truncations=True
            episode_success = terminations
            if args.bootstrap_at_done == 'never':
                stop_bootstrap = truncations | terminations # always stop bootstrap when episode ends
            else:
                if args.bootstrap_at_done == 'always':
                    need_final_obs = truncations | terminations # always need final obs when episode ends
                    stop_bootstrap = np.zeros_like(terminations, dtype=bool) # never stop bootstrap
                else: # bootstrap at truncated
                    need_final_obs = truncations & (~terminations) # only need final obs when truncated and not terminated
                    stop_bootstrap = terminations # only stop bootstrap when terminated, don't stop when truncated
            
            # 准备要存储的动作数据
            if args.critic_input == 'res' and args.actor_input == 'obs':
                actions_to_save = res_action
            else:  # sum or concat both need base actions for s and s'
                next_obs_dict = preprocess_observation(next_obs_seq, envs)
                with torch.no_grad():
                    base_next_act_seq = base_policy.predict_action(next_obs_dict)['action']
                    base_next_action = base_next_act_seq.squeeze(0)[0].cpu().numpy()
                base_action = base_action.cpu().numpy()
                res_action = res_action.reshape(1, -1) 
                base_action = base_action.reshape(1, -1)  # 变成 (1, action_dim)
                base_next_action = convert_24d_to_34d_action(th.from_numpy(base_next_action).to(device), envs.scene.robots[0]).cpu().numpy()
                base_next_action = base_next_action.reshape(1, -1)  # 变成 (1, action_dim)
                # print()
                # input()
                actions_to_save = np.concatenate([res_action, base_action, base_next_action])
            # 将数据添加到回放缓冲区
            next_proprio_obj_state = extract_proprio_obj_state(envs,next_obs_seq)
            ###改为proprio+obj state
            rb.add(proprio_obj_state, next_proprio_obj_state, actions_to_save, rewards, stop_bootstrap, infos)
            # print("")
            episode_success = terminations
            # terminations=True
            if terminations or truncations:
                max_reward_in_episode_ls.append(rewards)
                episode_count += 1
                # eval_episode_count += 1
                episode_successes.append(float(episode_success))
                if terminations:
                    serializable_log = convert_to_serializable(env_augmentor.log)
                    episode_log.append((serializable_log, "success"))
                else:  # truncations
                    serializable_log = convert_to_serializable(env_augmentor.log)
                    episode_log.append((serializable_log, "failure"))
                
                  # 记录本episode的成功情况
                # eval_episodes.append(float(episode_success))
                print("############### 重置环境！################")
                env_reset(envs,index % len(default_obj_config_paths))
                index += 1
                obs = envs.get_obs()
            else:
                obs = next_obs_seq


        timer.end('collect')
        if episode_count % args.save_interval == 0 and episode_count > 0 and episode_count != last_saved_count:
            last_saved_count = episode_count 
            # 计算最近n个episodes的成功率
            recent_success_rate = float(np.mean(episode_successes[-args.save_interval:]))
            # 计算并记录残差动作使用率
            res_usage_ratio = res_action_usage['res_steps'] / res_action_usage['total_steps']
            log_filename = os.path.join(results_dir, f"step_{global_step}.log")
            with open(log_filename, "w") as f:
                f.write("Training Results:\n")
                f.write(f"Global Step: {global_step}\n")
                f.write(f"Episode Count: {episode_count}\n")
                f.write(f"Success Rate: {recent_success_rate*100:.1f}% (averaged over last {args.save_interval} episodes)\n")
                f.write(f"Res Action Usage Ratio: {res_usage_ratio*100:.1f}%\n")
                # f.write(f"Current Res Scale: {current_res_scale:.4f}\n")
                f.write("\nEpisode Logs:\n")
                for log in episode_log[-args.save_interval:]:
                    f.write(f"{str(log)}\n")
            print(f"\nSaved results at episode {episode_count} (step {global_step})")
            print(f"Recent Success Rate: {recent_success_rate*100:.1f}%")
            print(f"max_reward_in_episode: {max_reward_in_episode}")

            # =============== 记录 wandb ===============
            wandb.log({
                    "eval/last_reward": np.mean(max_reward_in_episode_ls),
                    "eval/SR": recent_success_rate,
                    "actor/actor_loss": actor_loss.item(),
                    "actor/sac_alpha": sac_alpha, 
                    "actor/actor_grad_norm": actor_grad_norm.item(),
                    "critic/qf1_loss": qf1_loss.item(),
                    "critic/qf2_loss": qf2_loss.item(),
                    "critic/qf_loss": qf_loss.item() / 2.0,
                    "critic/qf1_grad_norm": qf1_grad_norm.item(),
                    "critic/qf2_grad_norm": qf2_grad_norm.item(),
                
                })
            max_reward_in_episode_ls = []
            max_reward_in_episode = -10
            # =============== 记录 wandb ===============


        # ALGO LOGIC: training.
        # 训练开始
        if rb.size() * rb.n_envs < args.learning_starts: #1000步以后开始训练 
            # print("continue collect")
            continue
        else:
            # print("start train")
            learning_has_started = True
        ################################## 训练 #########################################################
        for local_update in range(num_updates_per_training):
            # num_updates_per_training = 16 int(args.training_freq * args.utd) #16
            global_update += 1
            data = rb.sample(args.batch_size) # batch_size=1024 
            # 准备训练数据
            if args.critic_input != 'res' or args.actor_input == 'obs_base_action':
                # 从存储的动作中分离出残差动作和基础动作
                total_act_dim=34
                res_action = data.actions[:, :total_act_dim]
                base_actions = data.actions[:, total_act_dim:total_act_dim*2]
                base_next_actions = data.actions[:, -total_act_dim:].cpu()
            else:
                res_action = data.actions
            #############################################
            # Train agent
            # 训练智能体
            #############################################
            # update the value networks
            # 更新值网络        
            with torch.no_grad():
                # 确保输入数据在GPU上
                if args.actor_input == 'obs':
                    actor_input = data.next_observations
                else:
                    base_next_actions = base_next_actions.to(device)  # 移到GPU
                    actor_input = torch.cat([data.next_observations, base_next_actions], dim=1)
                next_state_res_actions, next_state_log_pi, _ = res_actor.get_action(actor_input)
                # 根据不同模式组合动作
                if args.critic_input == 'res':
                    next_state_actions = next_state_res_actions
                elif args.critic_input == 'sum':
                    scaled_res_actions = (args.res_scale * next_state_res_actions).to(device)
                    base_next_actions = base_next_actions.to(device)
                    next_state_actions = base_next_actions + scaled_res_actions
                else:  # concat
                    base_next_actions = base_next_actions.to(device)
                    next_state_actions = torch.cat([next_state_res_actions, base_next_actions], dim=1)

                # 计算目标Q值
                qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                qf2_next_target = qf2_target(data.next_observations, next_state_actions)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - sac_alpha * next_state_log_pi
                
                # 确保rewards和dones在GPU上
                rewards = data.rewards.to(device)
                dones = data.dones.to(device)
                # 检查维度
                next_q_value = rewards.flatten() + (1 - dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)
                next_q_value = next_q_value.detach()


            # 计算当前Q值
            if args.critic_input == 'res':
                current_actions = res_action.to(device)
            elif args.critic_input == 'sum':
                scaled_res_actions = (args.res_scale * res_action).to(device)
                base_actions = base_actions.to(device)
                current_actions = base_actions + scaled_res_actions
            else:  # concat
                res_action = res_action.to(device)
                base_actions = base_actions.to(device)
                current_actions = torch.cat([res_action, base_actions], dim=1)
            
            # 更新Q网络
            observations = data.observations.to(device)
            qf1_a_values = qf1(observations, current_actions).view(-1)
            qf2_a_values = qf2(observations, current_actions).view(-1)
            qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
            qf_loss = qf1_loss + qf2_loss

            q_optimizer.zero_grad()
            qf_loss.backward()
            qf1_grad_norm = nn.utils.clip_grad_norm_(qf1.parameters(), args.max_grad_norm)
            qf2_grad_norm = nn.utils.clip_grad_norm_(qf2.parameters(), args.max_grad_norm)
            q_optimizer.step()

            # update the policy network
            # 更新策略网络
            if global_update % args.policy_frequency == 0:
                if args.actor_input == 'obs':
                    actor_input = observations
                else:
                    base_actions = base_actions.to(device)
                    actor_input = torch.cat([observations, base_actions], dim=1)

                res_pi, log_pi, _ = res_actor.get_action(actor_input)
                # res_pi: 残差动作
                # log_pi: 动作的log概率，用于计算熵 
                if args.critic_input == 'res':
                    pi = res_pi
                elif args.critic_input == 'sum':
                    scaled_res_actions = (args.res_scale * res_pi).to(device)
                    base_actions = base_actions.to(device)
                    pi = base_actions + scaled_res_actions
                else:  # concat
                    base_actions = base_actions.to(device)
                    pi = torch.cat([res_pi, base_actions], dim=1)
                # 1. 计算当前策略下的Q值
                qf1_pi = qf1(observations, pi)
                qf2_pi = qf2(observations, pi)
                min_qf_pi = torch.min(qf1_pi, qf2_pi)
                actor_loss = ((sac_alpha * log_pi) - min_qf_pi).mean()
                # sac_alpha * log_pi: 熵正则项，鼓励探索
                # -min_qf_pi: 负的Q值，我们要最大化Q值
                # 整体是最小化 熵正则项 - Q值
                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_grad_norm = nn.utils.clip_grad_norm_(res_actor.parameters(), args.max_grad_norm)
                actor_optimizer.step()
            
                # 自动调整熵正则化系数
                if args.autotune:
                    with torch.no_grad():
                        _, log_pi, _ = res_actor.get_action(actor_input)
                    sac_alpha_loss = (-log_sac_alpha * (log_pi + target_entropy)).mean()

                    a_optimizer.zero_grad()
                    sac_alpha_loss.backward()
                    a_optimizer.step()
                    sac_alpha = log_sac_alpha.exp().item()

            # update the target networks
            # 更新目标网络
            if global_update % args.target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
        
        timer.end('train')
        # 记录训练相关数据 ，每100步记录一次log信息
        if (global_step - args.training_freq) // args.log_freq < global_step // args.log_freq:
            if len(result['return']) > 0:
                for k, v in result.items():
                    writer.add_scalar(f"train/{k}", np.mean(v), global_step)
                result = defaultdict(list)
            writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
            writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
            writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
            writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
            writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
            writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
            writer.add_scalar("losses/sac_alpha", sac_alpha, global_step)
            writer.add_scalar("losses/qf1_grad_norm", qf1_grad_norm.item(), global_step)
            writer.add_scalar("losses/qf2_grad_norm", qf2_grad_norm.item(), global_step)
            writer.add_scalar("losses/actor_grad_norm", actor_grad_norm.item(), global_step)
            # print("SPS:", int(global_step / (time.time() - start_time)))
            timer.dump_to_writer(writer, global_step)
            if args.autotune:
                writer.add_scalar("losses/sac_alpha_loss", sac_alpha_loss.item(), global_step)
                            # 记录最大reward
        # 保存检查点
        if args.save_freq and ( global_step >= args.total_timesteps or \
                (global_step - args.training_freq) // args.save_freq < global_step // args.save_freq):
            os.makedirs(f'{log_path}/checkpoints', exist_ok=True)
            print("############ 保存检查点！##############")
            torch.save({
                'res_actor': res_actor.state_dict(),
                'q1': qf1.state_dict(),
                'q2': qf2.state_dict(),
                'q1_target': qf1_target.state_dict(),
                'q2_target': qf2_target.state_dict(),
                'log_sac_alpha': log_sac_alpha if args.autotune else np.log(args.sac_alpha),
                # 保存优化器状态
                'actor_optimizer': actor_optimizer.state_dict(),
                'q_optimizer': q_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict() if args.autotune else None,
                'global_step': global_step,
            }, f'{log_path}/checkpoints/{global_step}.pt')

    envs.close()
    writer.close()
    envs.close()
    writer.close()
